from __future__ import absolute_import, division, print_function
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from itertools import product
np.set_printoptions(precision=2)
import time
import itertools


'''
    Find a feasible initial RX location for one D2D pair (Feasible_Loc_Init).
    The RX is placed at the TX location plus a random offset drawn uniformly from a
    square of side Dist_TX_RX centred on the TX. The offset is redrawn until the RX
    lies inside the simulation area and within Dist_TX_RX of the TX.
'''


def Feasible_Loc_Init(Cur_loc, Size_area, Dist_TX_RX):
    """Draw a random RX location near a TX by rejection sampling.

    Args:
        Cur_loc: TX location (x, y).
        Size_area: side length of the square simulation area centred at the origin.
        Dist_TX_RX: maximum allowed TX-RX distance.

    Returns:
        A (1, 2) array equal to Cur_loc plus an offset drawn uniformly from
        [-Dist_TX_RX/2, Dist_TX_RX/2)^2. The offset is redrawn until every
        coordinate of the result is within +-Size_area/2 and the offset norm
        does not exceed Dist_TX_RX.
    """
    half_side = Size_area / 2
    while True:
        offset = Dist_TX_RX * (np.random.rand(1, 2) - 0.5)
        candidate = Cur_loc + offset
        outside_area = np.max(abs(candidate)) > half_side
        too_far = np.linalg.norm(offset) > Dist_TX_RX
        if not (outside_area | too_far):
            return candidate


'''
    Find a feasible updated location for one user (Feasible_Loc_Update).
    The user moves exactly Delta_mov in a uniformly random direction from its current
    location. The move is redrawn until the new location lies inside the simulation
    area and within Dist_TX_RX of the reference TX location.

'''


def Feasible_Loc_Update(Cur_RX_loc, Cur_TX_loc, Size_area, Dist_TX_RX, Delta_mov):
    """Move a user by Delta_mov in a random direction, redrawing until feasible.

    Args:
        Cur_RX_loc: current location (x, y) of the user being moved.
        Cur_TX_loc: reference location the moved user must stay close to.
        Size_area: side length of the square area centred at the origin.
        Dist_TX_RX: maximum allowed distance to Cur_TX_loc.
        Delta_mov: length of the move.

    Returns:
        The new location (Cur_RX_loc plus a step of length Delta_mov) such that
        every coordinate is within +-Size_area/2 and its distance to Cur_TX_loc
        is at most Dist_TX_RX.

    The search starts from sentinel values (location 0, offset 2 * Dist_TX_RX).
    If those sentinels already pass both checks (e.g. Dist_TX_RX == 0), no move
    is drawn and 0 is returned.
    """
    half_side = Size_area / 2

    def _infeasible(location, offset_to_tx):
        return (np.max(abs(location)) > half_side) | (np.linalg.norm(offset_to_tx) > Dist_TX_RX)

    new_loc = 0
    offset_to_tx = 2 * Dist_TX_RX
    while _infeasible(new_loc, offset_to_tx):
        angle = 2 * np.pi * np.random.rand()
        step = [Delta_mov * np.cos(angle), Delta_mov * np.sin(angle)]
        new_loc = Cur_RX_loc + step
        offset_to_tx = Cur_TX_loc - new_loc

    return new_loc


'''
    Initialization of location information.
    It will return the location of one PU and SUs.
    The Users will be allocated to 2D area whose range is -Size_area/2 ~ Size_area/2.
    The distance between RX and TX for the same transmit pair is limited to "Dist_TX_RX"

    Input: Size_area, Dist_TX_RX, Num_D2D, Num_Ch
    -> Number of channel is same with the number of CUE

'''


def loc_init(Size_area, Dist_TX_RX, Num_D2D, Num_Ch):
    """Draw initial locations for D2D pairs and CUEs.

    Args:
        Size_area: side length of the square area centred at the origin.
        Dist_TX_RX: maximum D2D TX-RX distance.
        Num_D2D: number of D2D pairs.
        Num_Ch: number of CUEs (one per channel).

    Returns:
        (rx_loc, tx_loc, tx_loc_CUE):
        rx_loc     -- (Num_D2D + 1, 2); row i is the RX of D2D pair i, drawn with
                      Feasible_Loc_Init. The extra last row stays at the origin
                      (in ch_gen it is the receiver labelled BS).
        tx_loc     -- (Num_D2D, 2) D2D transmitters, uniform over the area.
        tx_loc_CUE -- (Num_Ch, 2) CUE transmitters, uniform over the area.
    """
    tx_loc = Size_area * (np.random.rand(Num_D2D, 2) - 0.5)

    rx_loc = np.zeros((Num_D2D + 1, 2))
    for pair, tx_xy in enumerate(tx_loc):
        rx_loc[pair, :] = Feasible_Loc_Init(tx_xy, Size_area, Dist_TX_RX)

    tx_loc_CUE = Size_area * (np.random.rand(Num_Ch, 2) - 0.5)
    return rx_loc, tx_loc, tx_loc_CUE


'''
    Update location of users.
    Update is conducted by considering the Delta_mov which is the amount of distance that users moves

    Input: Size_area, Dist_TX_RX, Num_D2D, Num_Ch
    -> Number of channel is same with the number of CUE

'''


def loc_update(Size_area, Dist_TX_RX, rx_loc, tx_loc, tx_loc_CUE, Delta_mov):
    """Move every user by Delta_mov in a random direction.

    D2D transmitters and CUEs only need to stay inside the area. Their
    distance limit is set to 2 * Size_area, which a single move of
    Delta_mov does not reach in practice. Each D2D receiver must also stay
    within Dist_TX_RX of its already-moved transmitter.

    The input arrays are not modified; updated float copies are returned.

    Args:
        Size_area: side length of the square area centred at the origin.
        Dist_TX_RX: maximum D2D TX-RX distance.
        rx_loc, tx_loc, tx_loc_CUE: current locations, one (x, y) row per user.
        Delta_mov: distance each user moves.

    Returns:
        (rx_loc_update, tx_loc_update, tx_loc_CUE_update)
    """
    # Copies: the previous "*_update = *" assignments were aliases and
    # silently overwrote the caller's arrays.
    rx_loc_update = np.array(rx_loc, dtype=float)
    tx_loc_update = np.array(tx_loc, dtype=float)
    tx_loc_CUE_update = np.array(tx_loc_CUE, dtype=float)

    Num_D2D = np.shape(tx_loc)[0]
    Num_CH = np.shape(tx_loc_CUE)[0]

    ## Determine the location of D2D users
    for i in range(Num_D2D):
        cur_tx = tx_loc_update[i, :].copy()
        ## Use 2*size_area to deactivate the distance condition
        tx_loc_update[i, :] = Feasible_Loc_Update(cur_tx, cur_tx, Size_area, 2 * Size_area, Delta_mov)
        # The RX must stay within Dist_TX_RX of the TX's *new* position.
        rx_loc_update[i, :] = Feasible_Loc_Update(rx_loc_update[i, :].copy(), tx_loc_update[i, :],
                                                  Size_area, Dist_TX_RX, Delta_mov)

    ## Determine the location of CUE Users
    for i in range(Num_CH):
        cur_cue = tx_loc_CUE_update[i, :].copy()
        tx_loc_CUE_update[i, :] = Feasible_Loc_Update(cur_cue, cur_cue, Size_area, 2 * Size_area, Delta_mov)

    return rx_loc_update, tx_loc_update, tx_loc_CUE_update


'''
    Determine the channel gain  (UPLINK)

    --------------------------------------------
    ** The basic setting
      : Pathloss exponent -> 3.8   (PL_alpha = 38 = 10 * 3.8, applied in dB)
      : Pathloss constant -> 34.5 dB
      : Moving speed      -> 3km/h
      : Power update      -> 100ms (3 km/h * 100 ms = 0.0833 m, the default Delta_mov)
    --------------------------------------------

    Each link gain is the path loss times unit-mean exponential (Rayleigh) fading;
    shadowing is not modelled.

    Note: ch_gen as implemented redraws the locations of all users for every sample,
    so consecutive samples are independent; the mobility setting above (Delta_mov)
    is currently not used.

    The location of the CUE is different for each band.

    The output looks like as follows:

    output[Sample][Channel][Users][Users]


    Example for 2 D2D and 1 CUE channel   :
    [
        h_{D2D_1 -> D2D_1}        h_{D2D_1->D2D_2}          h_{D2D_1->BS}
        h_{D2D_2 -> D2D_1}        h_{D2D_2->D2D_2}          h_{D2D_2->BS}
        h_{CUE -> D2D_1}          h_{CUE->D2D_2}            h_{CUE->BS}
    ]


    Accordingly, output[0] returns the first sample and
    output[0][0] returns the (N+1) X (N+1) channel gain matrix for the first band.

'''


def ch_gen(Size_area, D2D_dist, Num_D2D, Num_Ch, Num_samples, Delta_mov=0.0833, PL_alpha=38., PL_const=34.5):
    """Generate uplink channel-gain samples for Num_D2D D2D pairs sharing Num_Ch bands.

    For every sample, all user locations are drawn afresh with loc_init, so
    consecutive samples are independent. On band j, the CUE at tx_loc_CUE[j]
    is appended as the last transmitter. Each link gain is

        10 ** ((-PL_const - PL_alpha * log10(max(d, 5))) / 10) * fading

    Here fading is a unit-mean exponential variable (0.5 * (X**2 + Y**2) with
    X, Y standard normal). The gain is floored at exp(-30). Shadowing is not
    modelled.

    Args:
        Size_area: side length of the square area centred at the origin.
        D2D_dist: maximum D2D TX-RX distance.
        Num_D2D: number of D2D pairs.
        Num_Ch: number of bands (one CUE per band).
        Num_samples: number of samples to generate.
        Delta_mov: per-step movement distance. Currently unused, because
            locations are redrawn every sample; kept for interface compatibility.
        PL_alpha: path-loss slope in dB per decade of distance.
        PL_const: path-loss constant in dB.

    Returns:
        Nested list ch[sample][band] of (Num_D2D + 1, Num_D2D + 1) arrays.
        Element [t, r] is the gain from transmitter t to receiver r. Index
        Num_D2D is the CUE as transmitter and the BS as receiver.
    """
    n_nodes = Num_D2D + 1
    ch_w_fading = []

    # Note: no up-front loc_init call; its result was always overwritten
    # inside the loop and only consumed random draws.
    for _ in range(Num_samples):
        rx_loc, tx_loc, tx_loc_CUE = loc_init(Size_area, D2D_dist, Num_D2D, Num_Ch)

        ch_w_temp_band = []
        for j in range(Num_Ch):
            tx_loc_with_CUE = np.vstack((tx_loc, tx_loc_CUE[j]))
            # dist_vec[r, t]: distance from transmitter t to receiver r, floored at 5.
            dist_vec = np.linalg.norm(rx_loc.reshape(n_nodes, 1, 2) - tx_loc_with_CUE, axis=2)
            dist_vec = np.maximum(dist_vec, 5)

            # Path loss only (shadowing is not considered).
            pu_ch_gain = 10 ** ((-PL_const - PL_alpha * np.log10(dist_vec)) / 10)

            multi_fading = (0.5 * np.random.randn(n_nodes, n_nodes) ** 2
                            + 0.5 * np.random.randn(n_nodes, n_nodes) ** 2)
            final_ch = np.maximum(pu_ch_gain * multi_fading, np.exp(-30))
            # Transpose so rows index transmitters and columns index receivers.
            ch_w_temp_band.append(np.transpose(final_ch))

        ch_w_fading.append(ch_w_temp_band)
    return ch_w_fading


'''
Per-band helpers for one channel sample (one sample, one channel).

The shape of input channel is as follows:

    -> channel[users][users]

Note that the last index refers to the CUE as a transmitter (last row) and to the BS
as a receiver (last column).

    Example for 2 D2D and 1 CUE channel   :
    [
        h_{D2D_1 -> D2D_1}        h_{D2D_1->D2D_2}          h_{D2D_1->BS}
        h_{D2D_2 -> D2D_1}        h_{D2D_2->D2D_2}          h_{D2D_2->BS}
        h_{CUE -> D2D_1}          h_{CUE->D2D_2}            h_{CUE->BS}
    ]


'''






def cal_SINR_one_sample_one_channel(channel, tx_power, noise):
    """Return ln(1 + SINR) at every receiver for a batch of power vectors on one band.

    Args:
        channel: (N, N) gains; row t holds the gains from transmitter t, and the
            diagonal holds the direct links.
        tx_power: (batch, N) array, one power vector per row. A 2-D input is
            expected, because the receiver sums below run over axis 1.
        noise: noise power added to the interference.

    Returns:
        (batch, N) array of capacities in nats, one per receiver.
    """
    power_col = np.expand_dims(tx_power, -1)
    # Cross-link gains: the channel matrix with its diagonal zeroed.
    cross_gain = channel - np.diag(np.diag(channel))

    received = np.multiply(channel, power_col)
    interference_terms = np.multiply(cross_gain, power_col)

    desired = (received - interference_terms).sum(axis=1)
    interference = interference_terms.sum(axis=1)
    return np.log(1.0 + desired / (interference + noise))



def cal_CUE_INTER_one_sample_one_channel(channel, tx_power):
    """Return the interference power at every receiver for a batch of power vectors.

    Args:
        channel: (N, N) gains; row t holds the gains from transmitter t.
        tx_power: (batch, N) array, one power vector per row.

    Returns:
        (batch, N) array. Entry [k, r] is sum over t != r of channel[t, r] * tx_power[k, t].
    """
    off_diagonal = channel - np.diag(np.diag(channel))
    per_link = off_diagonal * np.expand_dims(tx_power, -1)
    return per_link.sum(axis=1)




'''
Find all possible combination

How to use it?

>> optimal_power(channel, granuity of transmit power, noise, CUE_thr)

CUE_thr: Determine the minimum data rate required for CUE
DUE_thr: Determine the minimum data rate required for D2D user

Returned value is the sum of all D2D users


'''

'''
Find all possible combination

How to use it?

>> optimal_power(channel, granuity of transmit power, noise, CUE_thr)

CUE_thr: Determine the minimum data rate required for CUE
DUE_thr: Determine the minimum data rate required for D2D user

Returned value is the sum of all D2D users


'''



def cal_rate_NP(channel, tx_power_in, tx_max, noise, DUE_thr, I_thr, P_c):
    """Evaluate given D2D power allocations on a batch of channel samples (NumPy).

    Args:
        channel: array of shape (num_sample, num_channel, U + 1, U + 1), where U
            is the number of D2D pairs. channel[s, c, t, r] is the gain from
            transmitter t to receiver r on channel c. Index U is the CUE as a
            transmitter and the BS as a receiver.
        tx_power_in: (num_sample, U, num_channel) power of every D2D transmitter
            on every channel.
        tx_max: CUE transmit power, used on every channel.
        noise: noise power added to the interference at every receiver.
        DUE_thr: minimum rate per D2D user (summed over channels, in nats).
        I_thr: maximum interference allowed at the BS on each channel.
        P_c: circuit power added to each user's total transmit power for EE.

    Returns:
        (tot_SE, tot_EE, tot_PW, PRO_CUE_vio, PRO_DUE_vio):
        tot_SE, tot_EE and tot_PW are the D2D rate, energy efficiency and
        transmit power, averaged over samples and D2D users.
        PRO_CUE_vio is the fraction of (sample, channel) pairs whose BS
        interference exceeds I_thr.
        PRO_DUE_vio is the fraction of (sample, D2D user) pairs whose rate is
        below DUE_thr. Both fractions lie in [0, 1].
    """
    channel = np.asarray(channel, dtype=float)
    tx_power_in = np.asarray(tx_power_in, dtype=float)
    num_sample, num_channel, num_node = channel.shape[:3]
    num_D2D_user = num_node - 1

    # The CUE (last transmitter) always transmits at tx_max on every channel.
    cue_power = tx_max * np.ones((tx_power_in.shape[0], 1, num_channel))
    tx_power = np.hstack((tx_power_in, cue_power))               # (S, U+1, C)

    # power[s, c, t, 0]: power of transmitter t on channel c (broadcast over receivers).
    power = np.transpose(tx_power, (0, 2, 1))[..., np.newaxis]   # (S, C, U+1, 1)
    node = np.arange(num_node)
    cross_gain = channel.copy()
    cross_gain[..., node, node] = 0.0                             # drop direct links

    signal = channel[..., node, node] * power[..., 0]             # (S, C, U+1)
    interference = np.sum(cross_gain * power, axis=2)             # sum over transmitters
    cap = np.sum(np.log(1.0 + signal / (interference + noise)), axis=1)  # (S, U+1)

    user_power = np.sum(tx_power, axis=2)                         # (S, U+1)
    d2d_cap = cap[:, :-1]
    d2d_power = user_power[:, :-1]

    tot_SE = np.sum(d2d_cap) / num_D2D_user / num_sample
    tot_EE = np.sum(d2d_cap / (d2d_power + P_c)) / num_D2D_user / num_sample
    tot_PW = np.sum(d2d_power) / num_D2D_user / num_sample

    # Exact counts (starting from zero) so the reported fractions never exceed 1.
    cue_vio = np.count_nonzero(interference[:, :, num_D2D_user] > I_thr)
    due_vio = np.count_nonzero(d2d_cap < DUE_thr)
    PRO_DUE_vio = due_vio / (num_sample * num_D2D_user)
    PRO_CUE_vio = cue_vio / (num_sample * num_channel)

    return tot_SE, tot_EE, tot_PW, PRO_CUE_vio, PRO_DUE_vio




'''
    Find all possible combination
    Use product to find combination
'''


def all_possible_tx_power(num_channel, num_user, granuty):
    """Enumerate every quantised power allocation that respects each user's budget.

    Each user picks a level in {0, ..., granuty - 1} on every channel. An
    allocation is kept when every user's levels sum to at most granuty - 1
    (total normalised power <= 1) and not all levels are zero.

    Args:
        num_channel: number of channels.
        num_user: number of users.
        granuty: number of power levels per channel (must be >= 2).

    Returns:
        Float array of shape (num_alloc, num_user, num_channel) with entries in
        [0, 1] (levels divided by granuty - 1), in itertools.product order.

    Raises:
        ValueError: if granuty < 2, because normalising by granuty - 1 is undefined.

    Note: this enumerates granuty ** (num_user * num_channel) candidates.
    """
    if granuty < 2:
        raise ValueError("granuty must be at least 2, got %r" % (granuty,))

    levels = np.arange(granuty)
    combos = np.array(list(itertools.product(levels, repeat=num_user * num_channel)))
    combos = combos.reshape(-1, num_user, num_channel)

    within_budget = np.all(combos.sum(axis=2) <= granuty - 1, axis=1)
    non_zero = np.any(combos != 0, axis=(1, 2))
    return combos[within_budget & non_zero] / (granuty - 1)




def WMMSE_sum_rate(p_int, H, Pmax, int_cell):
    """Weighted MMSE (WMMSE) power control for sum-rate maximisation.

    The routine alternately updates receiver gains f, MSE weights w and transmit
    amplitudes b (b**2 = power). It runs for up to 100 iterations and stops once
    the objective sum(log w) improves by at most 1e-3.

    Args:
        p_int: initial transmit powers, one per link (K = np.size(p_int)).
        H: (K, K) amplitude gains. The updates use H[i, j] as the gain from
            transmitter j to receiver i, with H[i, i] the direct link.
            NOTE(review): this is the transpose of the [tx][rx] layout of the
            channel matrices elsewhere in this file, and H holds amplitudes
            rather than power gains. Confirm at call sites.
        Pmax: per-link power budget, a scalar or anything broadcastable to (K,).
        int_cell: additional interference-plus-noise power at each receiver.

    Returns:
        Optimised transmit powers b**2, each within [0, Pmax].
    """
    K = np.size(p_int)
    b = np.sqrt(np.asarray(p_int, dtype=float))
    b_max = np.broadcast_to(np.sqrt(np.asarray(Pmax, dtype=float)), (K,))
    f = np.zeros(K)
    w = np.zeros(K)

    vnew = 0
    for i in range(K):
        f[i] = H[i, i] * b[i] / (np.matmul(np.square(H[i, :]), np.square(b)) + int_cell[i])
        w[i] = 1 / (1 - f[i] * b[i] * H[i, i])
        vnew = vnew + np.log(w[i])

    for _ in range(100):
        vold = vnew
        for i in range(K):
            btmp = w[i] * f[i] * H[i, i] / np.sum(w * np.square(f) * np.square(H[:, i]))
            # Project onto [0, sqrt(Pmax)]. np.clip avoids the cancellation of the
            # former min(x, s) + max(x, 0) - x form when btmp >> sqrt(Pmax).
            b[i] = np.clip(btmp, 0.0, b_max[i])

        vnew = 0
        for i in range(K):
            f[i] = H[i, i] * b[i] / (np.matmul(np.square(H[i, :]), np.square(b)) + int_cell[i])
            w[i] = 1 / (1 - f[i] * b[i] * H[i, i] + 1e-12)
            vnew = vnew + np.log(w[i])

        if vnew - vold <= 1e-3:
            break

    return np.square(b)



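'''
Exhaustive search over candidate power allocations (optimal_power).

How to use it?

>> optimal_power(channel, tx_max, granuty, noise, DUE_thr, I_thr, P_c, tx_power_set)

DUE_thr: minimum data rate required for each D2D user
I_thr:   maximum interference allowed at the BS on each channel
P_c:     circuit power added to each user's transmit power for energy efficiency

Returns the per-D2D-user averages of the best sum rate (SE), best energy efficiency (EE)
and minimum transmit power (PW) over the feasible samples, the corresponding power
allocations, and the feasible channel samples.


'''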


def optimal_power(channel, tx_max, granuty, noise, DUE_thr, I_thr, P_c, tx_power_set):
    """Exhaustively search candidate allocations for each channel sample.

    A candidate is feasible for a sample when the interference at the BS (last
    receiver) is below I_thr on every channel and every D2D user's rate, summed
    over channels, is above DUE_thr. A sample counts as a success when the best
    feasible SE, the best feasible EE and the minimum power are all positive.

    Args:
        channel: (num_samples, num_channel, U + 1, U + 1) gains, as produced by ch_gen.
        tx_max: power scale applied to tx_power_set; also the CUE power.
        granuty: unused; kept for interface compatibility.
        noise: receiver noise power.
        DUE_thr: minimum D2D rate.
        I_thr: maximum interference at the BS.
        P_c: circuit power used in the energy-efficiency denominator.
        tx_power_set: (num_candidates, U, num_channel) normalised candidate
            D2D allocations.

    Returns:
        (tot_SE, tot_EE, tot_PW, power_mat_SE, power_mat_EE, power_mat_PW,
        feasible_channel_mat):
        - The totals are averaged over successful samples and D2D users.
        - The power matrices hold the chosen D2D rows (scaled by tx_max) per
          successful sample.
        - feasible_channel_mat holds those samples' channels.

    Raises:
        ZeroDivisionError: if no sample is successful.

    Prints the sample index every 10000 samples.
    """
    num_samples = channel.shape[0]
    num_channel = channel.shape[1]
    num_D2D_user = channel.shape[2] - 1

    # Candidates: D2D rows from tx_power_set plus the CUE at full power, all scaled by tx_max.
    tx_power = tx_max * np.hstack((tx_power_set, np.ones((tx_power_set.shape[0], 1, num_channel))))

    # Power-only quantities do not depend on the channel sample; compute them once.
    user_power = np.sum(tx_power, axis=2)          # (num_candidates, U+1)
    summed_tx_power = user_power + P_c
    PW_all = np.sum(user_power[:, :-1], axis=1)    # total D2D power per candidate

    num_success = 0
    tot_SE = 0
    tot_EE = 0
    tot_PW = 0
    power_mat_SE = []
    power_mat_EE = []
    power_mat_PW = []
    feasible_channel_mat = []

    for i in range(num_samples):
        if i % 10000 == 0:
            print(i)

        cur_cap = 0
        CUE_mask = 1
        for j in range(num_channel):
            cur_ch = channel[i][j]
            cur_cap = cur_cap + cal_SINR_one_sample_one_channel(cur_ch, tx_power[:, :, j], noise)
            inter = cal_CUE_INTER_one_sample_one_channel(cur_ch, tx_power[:, :, j])
            CUE_mask = CUE_mask * (inter[:, num_D2D_user] < I_thr)

        DUE_mask = 1
        for j in range(num_D2D_user):
            DUE_mask = DUE_mask * (cur_cap[:, j] > DUE_thr)

        SE_sum = np.sum(cur_cap[:, :-1], axis=1) * CUE_mask * DUE_mask
        EE_sum = np.sum(cur_cap[:, :-1] / summed_tx_power[:, :-1], axis=1) * CUE_mask * DUE_mask
        # Infeasible candidates get a huge power, so the minimisation avoids them.
        PW_sum = PW_all * (1 / (CUE_mask + 1e-10)) * (1 / (DUE_mask + 1e-10))

        D2D_SE_arg = np.argmax(SE_sum)
        D2D_EE_arg = np.argmax(EE_sum)
        # Candidate 0 is excluded from the power minimisation; shift the argmin
        # over PW_sum[1:] back to an index into tx_power.
        D2D_PW_arg = np.argmin(PW_sum[1:]) + 1

        max_SE = np.max(SE_sum)
        max_EE = np.max(EE_sum)
        min_PW = np.min(PW_sum[1:])

        if max_SE > 0 and max_EE > 0 and min_PW > 0:
            feasible_channel_mat.append(channel[i])
            power_mat_EE.append(tx_power[D2D_EE_arg][:-1])
            power_mat_SE.append(tx_power[D2D_SE_arg][:-1])
            power_mat_PW.append(tx_power[D2D_PW_arg][:-1])
            num_success = num_success + 1
            tot_SE = tot_SE + max_SE
            tot_EE = tot_EE + max_EE
            tot_PW = tot_PW + min_PW

    tot_SE = tot_SE / num_success / num_D2D_user
    tot_EE = tot_EE / num_success / num_D2D_user
    tot_PW = tot_PW / num_success / num_D2D_user

    return (tot_SE, tot_EE, tot_PW, np.array(power_mat_SE), np.array(power_mat_EE),
            np.array(power_mat_PW), np.array(feasible_channel_mat))





'''
    This function is used to divide feasible channels and infeasible channels
'''


def optimal_power_check_valid(channel, tx_max, granuty, noise, DUE_thr, I_thr, tx_power_set):
    """Split channel samples into feasible and infeasible ones.

    A sample is feasible when at least one candidate allocation meets all three
    conditions:
    - the interference at the BS (last receiver) stays below I_thr on every channel;
    - every D2D user's rate, summed over channels, is above DUE_thr;
    - the D2D sum rate is positive.
    Candidates come from tx_power_set scaled by tx_max, with the CUE appended at
    tx_max. These are the same constraints optimal_power applies.

    Args:
        channel: (num_samples, num_channel, U + 1, U + 1) gains.
        tx_max: power scale applied to tx_power_set; also the CUE power.
        granuty: unused; kept for interface compatibility.
        noise: receiver noise power.
        DUE_thr: minimum D2D rate.
        I_thr: maximum interference at the BS.
        tx_power_set: (num_candidates, U, num_channel) normalised candidates.

    Returns:
        (feasible_channel_mat, infeasible_channel_mat) as numpy arrays of samples.

    Prints the sample index every 10 samples.
    """
    num_samples = channel.shape[0]
    num_channel = channel.shape[1]
    num_D2D_user = channel.shape[2] - 1

    tx_power = tx_max * np.hstack((tx_power_set, np.ones((tx_power_set.shape[0], 1, num_channel))))

    feasible_channel_mat = []
    infeasible_channel_mat = []

    for i in range(num_samples):
        if i % 10 == 0:
            print(i)

        cur_cap = 0
        CUE_mask = 1
        for j in range(num_channel):
            cur_ch = channel[i][j]
            cur_cap = cur_cap + cal_SINR_one_sample_one_channel(cur_ch, tx_power[:, :, j], noise)
            inter = cal_CUE_INTER_one_sample_one_channel(cur_ch, tx_power[:, :, j])
            # Feasible only if the BS interference is BELOW the threshold
            # (previously inverted to "> I_thr").
            CUE_mask = CUE_mask * (inter[:, num_D2D_user] < I_thr)

        DUE_mask = 1
        for j in range(num_D2D_user):
            DUE_mask = DUE_mask * (cur_cap[:, j] > DUE_thr)

        D2D_sum = np.sum(cur_cap[:, :-1], axis=1) * CUE_mask * DUE_mask
        if np.max(D2D_sum) > 0:
            feasible_channel_mat.append(channel[i])
        else:
            infeasible_channel_mat.append(channel[i])

    return np.array(feasible_channel_mat), np.array(infeasible_channel_mat)


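'''
    Convert power allocations into one-hot encodings: for each user, a one-hot vector of the
    quantised power level on its strongest channel, followed by a one-hot vector of that channel index.
'''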
def convert_optimal_power(tx_power_mat, tx_max, Num_power_level, Num_channel):
    """Encode power allocations as concatenated one-hot (power level, channel) vectors.

    For each sample and user, the channel with the largest power is selected.
    That power is mapped to the nearest of Num_power_level levels (this assumes
    powers lie in [0, tx_max]). The output vector has length
    Num_power_level + Num_channel, with a one at the level index and a one at
    Num_power_level + channel index.

    Args:
        tx_power_mat: (num_samples, num_user, num_channel) power allocations.
        tx_max: maximum power, mapped to the highest level.
        Num_power_level: number of quantisation levels.
        Num_channel: number of channels.

    Returns:
        Float array of shape (num_samples, num_user, Num_power_level + Num_channel).
    """
    code_len = Num_power_level + Num_channel

    def encode(user_powers):
        best_channel = np.argmax(user_powers)
        level = np.round(user_powers[best_channel] / tx_max * (Num_power_level - 1))
        one_hot = np.zeros((code_len,))
        one_hot[int(level)] = 1
        one_hot[int(Num_power_level + best_channel)] = 1
        return one_hot

    n_samples, n_users = tx_power_mat.shape[0], tx_power_mat.shape[1]
    return np.array([
        np.array([encode(tx_power_mat[s, u]) for u in range(n_users)])
        for s in range(n_samples)
    ])


'''
Find random power
'''


def random_power(channel, tx_max, granuty, noise, DUE_thr, I_thr, P_c):
    """Evaluate a random power-allocation baseline on a batch of channel samples.

    For each sample, every D2D user draws a random split of its power across the
    channels and a random overall fraction of tx_max. The CUE (last user)
    transmits at tx_max on every channel.

    Args:
        channel: (num_sample, num_channel, U + 1, U + 1) gains.
        tx_max: maximum D2D power; also the CUE power.
        granuty: unused; kept for interface compatibility.
        noise: receiver noise power.
        DUE_thr: minimum D2D rate (summed over channels).
        I_thr: maximum interference at the BS (last receiver).
        P_c: circuit power used in the energy-efficiency denominator.

    Returns:
        (tot_SE, tot_EE, tot_PW, PRO_CUE_vio, PRO_DUE_vio):
        tot_SE, tot_EE and tot_PW are the D2D rate, energy efficiency and
        power, averaged over samples and D2D users.
        PRO_CUE_vio is the fraction of (sample, channel) pairs whose BS
        interference exceeds I_thr.
        PRO_DUE_vio is the fraction of (sample, D2D user) pairs whose rate is
        below DUE_thr.
    """
    num_sample = channel.shape[0]
    num_channel = channel.shape[1]
    num_D2D_user = channel.shape[2] - 1

    tot_SE = 0
    tot_EE = 0
    tot_PW = 0
    cue_vio_count = 0
    due_vio_count = 0

    for i in range(num_sample):
        # Random split of each user's power over the channels ...
        tx_power = np.random.random((num_D2D_user + 1, num_channel))
        tx_power = tx_power / np.reshape(np.sum(tx_power, axis=1), (-1, 1))
        # ... scaled by a random fraction of tx_max; the CUE always uses tx_max.
        tx_power = tx_max * tx_power * np.random.random((num_D2D_user + 1, 1))
        tx_power[num_D2D_user, :] = tx_max

        cur_cap = 0
        for j in range(num_channel):
            cur_ch = channel[i][j]
            tx_power_select = np.reshape(tx_power[:, j], (1, -1))
            cur_cap = cur_cap + cal_SINR_one_sample_one_channel(cur_ch, tx_power_select, noise)[0]
            inter = cal_CUE_INTER_one_sample_one_channel(cur_ch, tx_power_select)
            if inter[0, num_D2D_user] > I_thr:
                cue_vio_count += 1

        due_vio_count += np.count_nonzero(cur_cap[:num_D2D_user] < DUE_thr)

        user_power = np.sum(tx_power, axis=1)
        tot_SE = tot_SE + np.sum(cur_cap[:-1])
        tot_EE = tot_EE + np.sum(cur_cap[:-1] / (user_power[:-1] + P_c))
        tot_PW = tot_PW + np.sum(user_power[:-1])

    tot_SE = tot_SE / num_D2D_user / num_sample
    tot_EE = tot_EE / num_D2D_user / num_sample
    tot_PW = tot_PW / num_D2D_user / num_sample
    # Counts start at zero so these are exact fractions in [0, 1].
    PRO_DUE_vio = due_vio_count / (num_sample * num_D2D_user)
    PRO_CUE_vio = cue_vio_count / (num_sample * num_channel)

    return tot_SE, tot_EE, tot_PW, PRO_CUE_vio, PRO_DUE_vio





'''

    LOSS MODELS

'''


'''
    This function calculates the capacity of the D2D users and of the CUE.
    It returns the D2D capacity and the CUE capacity.
    The D2D capacity is accumulated over all channels.

    The shapes of the returned tensors are:

    D2D capacity: [Num_sample, Num_D2D_users]
    CUE capacity: [Num_sample, Num_channels]

'''







def cal_RATE_tf(channel, tx_power, tx_max, noise, num_samples, log_data_mean, log_data_std):
    """Differentiable per-user capacities (nats) for a batch of normalised log-channels.

    The channel is de-normalised as exp(channel * log_data_std + log_data_mean).
    This presumably inverts a log + standardisation applied to the training
    data; confirm against the data pipeline.

    Args:
        channel: tensor indexed [sample, band, tx, rx].
        tx_power: D2D powers indexed [sample, user, band]. The CUE is appended
            as the last transmitter at tx_max on every band.
        tx_max: CUE power.
        noise: receiver noise power.
        num_samples: static batch size (used to build the CUE power column).
        log_data_mean, log_data_std: de-normalisation constants.

    Returns:
        (cap_val_D2D, CUE_cap):
        cap_val_D2D is [sample, D2D user], summed over bands.
        CUE_cap is [sample, band], the CUE/BS link capacity per band.

    Uses only ops that exist in both TF 1.x and 2.x (``tf.div``/``tf.log`` were
    removed in TF 2).
    """
    chan_num = channel.shape[1]
    user_num = channel.shape[2]
    cap_val = tf.constant(0.0)
    CUE_cap = []

    channel_rev = tf.exp(channel * log_data_std + log_data_mean)

    for i in range(chan_num):
        tx_power_w_CUE = tf.concat([tx_power[:, :, i], tx_max * tf.ones((num_samples, 1))], axis=1)
        # Row t of each sample's matrix is scaled by transmitter t's power.
        tot_ch = tf.multiply(channel_rev[:, i], tf.expand_dims(tx_power_w_CUE, -1))
        sig_ch = tf.linalg.diag_part(tot_ch)
        inter_ch = tf.reduce_sum(tot_ch - tf.linalg.diag(sig_ch), axis=1)
        band_cap = tf.math.log(1.0 + sig_ch / (inter_ch + noise))
        cap_val = cap_val + band_cap
        CUE_cap.append(band_cap[:, -1])

    cap_val_D2D = cap_val[:, :user_num - 1]
    return cap_val_D2D, tf.stack(CUE_cap, axis=1)



def cal_inter_tf(channel, tx_power, tx_max, noise, num_samples, log_data_mean, log_data_std):
    """Interference received at the last receiver (the BS) on every band.

    The channel is de-normalised as exp(channel * log_data_std + log_data_mean),
    and the CUE is appended as the last transmitter at tx_max. noise is unused
    and kept for signature parity with cal_RATE_tf.

    Returns:
        Tensor indexed [sample, band].
    """
    gains = tf.exp(channel * log_data_std + log_data_mean)
    bs_interference = []
    for band in range(channel.shape[1]):
        powers = tf.concat([tx_power[:, :, band], tx_max * tf.ones((num_samples, 1))], axis=1)
        received = tf.multiply(gains[:, band], tf.expand_dims(powers, -1))
        own_link = tf.linalg.diag(tf.linalg.diag_part(received))
        leaked = tf.reduce_sum(received - own_link, axis=1)
        bs_interference.append(leaked[:, -1])
    return tf.transpose(tf.convert_to_tensor(bs_interference))





def cal_EE_tf(cap_val_D2D, tx_power, p_c):
    """Energy efficiency per D2D user: rate / (p_c + total transmit power).

    tx_power is summed over axis 2 (the band axis in the [sample, user, band]
    layout used by the losses). The ``/`` operator replaces the TF1-only
    ``tf.div`` alias, which was removed in TF 2.
    """
    tx_power_d2d = tf.reduce_sum(tx_power, axis=2)
    return cap_val_D2D / (p_c + tx_power_d2d)

def cal_EE_tf_temp(cap_val_D2D, tx_power, p_c):
    """Placeholder EE: 1 / (p_c + total transmit power); cap_val_D2D is ignored.

    The ``/`` operator replaces the TF1-only ``tf.div`` alias, which was removed
    in TF 2.
    """
    tx_power_d2d = tf.reduce_sum(tx_power, axis=2)
    return 1.0 / (p_c + tx_power_d2d)



def cal_LOSS_Total_SE_tf(channel, tf_output, noise, DUE_thr, I_thr, tx_max, num_samples, log_data_mean, log_data_std,
                      lambda_mat):
    """Penalised sum-rate loss.

    Loss = -lambda_mat[0] * mean D2D rate per sample
           + lambda_mat[1] * mean tanh(relative D2D rate shortfall)
           + lambda_mat[2] * mean tanh(relative BS interference excess)

    The network output is clipped at 1 and scaled by tx_max to obtain powers
    indexed [sample, user, band].
    """
    tx_pow_chan = tf.minimum(tf_output, 1.0) * tx_max

    D2D_rate, _ = cal_RATE_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)
    I_val = cal_inter_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)

    # Smooth, bounded violation measures (relative shortfall / excess, squashed by tanh).
    D2D_vio = tf.nn.tanh(tf.nn.relu(DUE_thr - D2D_rate) / (DUE_thr + 1e-10))
    CUE_vio = tf.nn.tanh(tf.nn.relu(I_val - I_thr) / (I_thr + 1e-10))
    D2D_vio_sum = tf.reduce_mean(D2D_vio)
    CUE_vio_sum = tf.reduce_mean(CUE_vio)

    return (-lambda_mat[0] * tf.reduce_mean(D2D_rate, axis=1)
            + lambda_mat[1] * D2D_vio_sum
            + lambda_mat[2] * CUE_vio_sum)



def cal_LOSS_Total_EE_tf(channel, tf_output, noise, DUE_thr, I_thr, tx_max, num_samples, log_data_mean, log_data_std,
                      lambda_mat, P_c):
    """Penalised energy-efficiency loss.

    Loss = -1000 * lambda_mat[0] * mean D2D EE per sample
           + lambda_mat[1] * mean tanh(relative D2D rate shortfall)
           + lambda_mat[2] * mean tanh(relative BS interference excess)

    The 1000 factor presumably rescales the small EE values against the
    penalties; confirm. The network output is clipped at 1 and scaled by tx_max.
    """
    tx_pow_chan = tf.minimum(tf_output, 1.0) * tx_max

    D2D_rate, _ = cal_RATE_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)
    EE_rate = cal_EE_tf(D2D_rate, tx_pow_chan, P_c)
    I_val = cal_inter_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)

    D2D_vio = tf.nn.tanh(tf.nn.relu(DUE_thr - D2D_rate) / (DUE_thr + 1e-10))
    CUE_vio = tf.nn.tanh(tf.nn.relu(I_val - I_thr) / (I_thr + 1e-10))
    D2D_vio_sum = tf.reduce_mean(D2D_vio)
    CUE_vio_sum = tf.reduce_mean(CUE_vio)

    return (-1000.0 * lambda_mat[0] * tf.reduce_mean(EE_rate, axis=1)
            + lambda_mat[1] * D2D_vio_sum
            + lambda_mat[2] * CUE_vio_sum)



def cal_LOSS_Total_PW_tf(channel, tf_output, noise, DUE_thr, I_thr, tx_max, num_samples, log_data_mean, log_data_std,
                      lambda_mat):
    """Penalised transmit-power loss.

    Loss = lambda_mat[0] * mean of the powers over axis 1
           + lambda_mat[1] * mean tanh(relative D2D rate shortfall)
           + lambda_mat[2] * mean tanh(relative BS interference excess)

    With powers indexed [sample, user, band], the first term averages over
    users only and so keeps a per-band axis.
    """
    tx_pow_chan = tf.minimum(tf_output, 1.0) * tx_max

    D2D_rate, _ = cal_RATE_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)
    I_val = cal_inter_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)

    D2D_vio = tf.nn.tanh(tf.nn.relu(DUE_thr - D2D_rate) / (DUE_thr + 1e-10))
    CUE_vio = tf.nn.tanh(tf.nn.relu(I_val - I_thr) / (I_thr + 1e-10))
    D2D_vio_sum = tf.reduce_mean(D2D_vio)
    CUE_vio_sum = tf.reduce_mean(CUE_vio)

    return (lambda_mat[0] * tf.reduce_mean(tx_pow_chan, axis=1)
            + lambda_mat[1] * D2D_vio_sum
            + lambda_mat[2] * CUE_vio_sum)



def cal_LOSS_rate_tf(channel, tf_output, noise, tx_max, num_samples, log_data_mean, log_data_std):
    """Unconstrained rate loss: negative mean D2D rate per sample (output scaled by tx_max)."""
    rates, _cue_rates = cal_RATE_tf(channel, tf_output * tx_max, tx_max, noise, num_samples,
                                    log_data_mean, log_data_std)
    return -tf.reduce_mean(rates, axis=1)




def cal_LOSS_ee_tf(channel, tf_output, noise, tx_max, num_samples, log_data_mean, log_data_std, P_c):
    """Unconstrained EE loss: negative mean D2D energy efficiency per sample."""
    powers = tf_output * tx_max
    rates = cal_RATE_tf(channel, powers, tx_max, noise, num_samples, log_data_mean, log_data_std)[0]
    return -tf.reduce_mean(cal_EE_tf(rates, powers, P_c), axis=1)


def cal_LOSS_ee_tf_temp(channel, tf_output, noise, tx_max, num_samples, log_data_mean, log_data_std, P_c):
    """Negative mean of cal_EE_tf_temp per sample (the rate graph is still built, then ignored)."""
    powers = tf_output * tx_max
    rates = cal_RATE_tf(channel, powers, tx_max, noise, num_samples, log_data_mean, log_data_std)[0]
    return -tf.reduce_mean(cal_EE_tf_temp(rates, powers, P_c), axis=1)


def cal_LOSS_pw_tf_temp(channel, tf_output, noise, tx_max, num_samples, log_data_mean, log_data_std):
    """Power loss: mean over axis 1 of the scaled network output.

    Only tf_output and tx_max are used; the remaining parameters keep the
    signature aligned with the other loss functions.
    """
    return tf.reduce_mean(tf_output * tx_max, axis=1)





def cal_LOSS_DUE_CONST_tf(channel, tf_output, noise, DUE_thr, CUE_thr, tx_max, num_samples, log_data_mean, log_data_std,
                      lambda_mat):
    """Fraction of D2D users per sample whose rate is below DUE_thr.

    This is a 0/1 count with no gradient with respect to tf_output. CUE_thr and
    lambda_mat are unused and kept for interface compatibility.
    """
    tx_pow_chan = tf_output * tx_max
    D2D_rate, _ = cal_RATE_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)
    D2D_vio = tf.cast(DUE_thr > D2D_rate, tf.float32)
    return tf.reduce_mean(D2D_vio, axis=1)




def cal_LOSS_CUE_CONST_tf(channel, tf_output, noise, DUE_thr, I_thr, tx_max, num_samples, log_data_mean, log_data_std,
                      lambda_mat):
    """Fraction of bands per sample on which the interference at the BS exceeds I_thr.

    This is a 0/1 count with no gradient with respect to tf_output. DUE_thr and
    lambda_mat are unused and kept for interface compatibility. The unused
    ``int(Num_power_level)`` lookup was dropped: that name is not defined in
    this function's scope.
    """
    tx_pow_chan = tf_output * tx_max
    I_val = cal_inter_tf(channel, tx_pow_chan, tx_max, noise, num_samples, log_data_mean, log_data_std)
    CUE_vio = tf.cast(I_val > I_thr, tf.float32)
    return tf.reduce_mean(CUE_vio, axis=1)





'''
    Mean squared error between the network output and a target (cal_LOSS_init_tf).
'''



def cal_LOSS_init_tf(channel, tf_output, y_true):
    """Mean squared error between tf_output and y_true; channel is unused.

    The squared error is averaged over axis 2 first, then over all remaining elements.
    """
    squared_error = tf.pow(tf_output - y_true, 2)
    per_row = tf.reduce_mean(squared_error, axis=2)
    return tf.reduce_mean(per_row)



def Total_SE_loss_wrapper(input_tensor, noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat):
    """Return a Keras loss ``TOTAL_SE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_Total_SE_tf; y_true is ignored.
    """
    bound_args = (noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat)

    def TOTAL_SE_loss(y_true, y_pred):
        return cal_LOSS_Total_SE_tf(input_tensor, y_pred, *bound_args)

    return TOTAL_SE_loss


def Total_EE_loss_wrapper(input_tensor, noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat, P_c):
    """Return a Keras loss ``TOTAL_EE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_Total_EE_tf; y_true is ignored.
    """
    bound_args = (noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat, P_c)

    def TOTAL_EE_loss(y_true, y_pred):
        return cal_LOSS_Total_EE_tf(input_tensor, y_pred, *bound_args)

    return TOTAL_EE_loss


def Total_PW_loss_wrapper(input_tensor, noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat):
    """Return a Keras loss ``TOTAL_PW_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_Total_PW_tf; y_true is ignored.
    """
    bound_args = (noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat)

    def TOTAL_PW_loss(y_true, y_pred):
        return cal_LOSS_Total_PW_tf(input_tensor, y_pred, *bound_args)

    return TOTAL_PW_loss




def Rate_loss_wrapper(input_tensor, noise,  tx_max, num_sample, log_data_mean, log_data_std):
    """Return a Keras loss ``RATE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_rate_tf; y_true is ignored.
    """
    bound_args = (noise, tx_max, num_sample, log_data_mean, log_data_std)

    def RATE_loss(y_true, y_pred):
        return cal_LOSS_rate_tf(input_tensor, y_pred, *bound_args)

    return RATE_loss



def EE_loss_wrapper(input_tensor, noise, tx_max, num_sample, log_data_mean, log_data_std, P_c):
    """Return a Keras loss ``EE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_ee_tf; y_true is ignored.
    """
    bound_args = (noise, tx_max, num_sample, log_data_mean, log_data_std, P_c)

    def EE_loss(y_true, y_pred):
        return cal_LOSS_ee_tf(input_tensor, y_pred, *bound_args)

    return EE_loss




def PW_loss_wrapper(input_tensor, noise, tx_max, num_sample, log_data_mean, log_data_std):
    """Return a Keras loss ``PW_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_pw_tf_temp; y_true is ignored.
    """
    bound_args = (noise, tx_max, num_sample, log_data_mean, log_data_std)

    def PW_loss(y_true, y_pred):
        return cal_LOSS_pw_tf_temp(input_tensor, y_pred, *bound_args)

    return PW_loss





def INT_CONST_loss_wrapper(input_tensor):
    """Return a Keras loss ``INT_loss(y_true, y_pred)`` that penalises non-integer outputs.

    The loss is (y_pred - round(y_pred))**2, averaged over axis 2 and then over
    axis 1. y_true and input_tensor are unused; input_tensor is kept for
    interface compatibility. The dead ``int(input_tensor.shape[1])``
    evaluation was removed.
    """
    def INT_loss(y_true, y_pred):
        rounded_value = tf.round(y_pred)
        return tf.reduce_mean(tf.reduce_mean(tf.square(y_pred - rounded_value), axis=2), axis=1)

    return INT_loss


def RA_CONST_loss_wrapper(input_tensor):
    """Return a Keras loss ``RA_loss(y_true, y_pred)`` pushing the last channel entries to integers.

    The last input_tensor.shape[1] entries along axis 2 of y_pred are compared
    with their rounded values. The squared gap is averaged over axis 2, then
    over axis 1. y_true is ignored.
    """
    def RA_loss(y_true, y_pred):
        n_channels = int(input_tensor.shape[1])
        rounded = tf.round(y_pred)
        gap = y_pred[:, :, -n_channels:] - rounded[:, :, -n_channels:]
        per_user = tf.reduce_mean(tf.square(gap), axis=2)
        return tf.reduce_mean(per_user, axis=1)

    return RA_loss


def DUE_CONST_loss_wrapper(input_tensor, noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat):
    """Return a Keras metric-style loss ``DUE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_DUE_CONST_tf (I_thr is passed in its
    CUE_thr slot); y_true is ignored.
    """
    bound_args = (noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat)

    def DUE_loss(y_true, y_pred):
        return cal_LOSS_DUE_CONST_tf(input_tensor, y_pred, *bound_args)

    return DUE_loss

def CUE_CONST_loss_wrapper(input_tensor, noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat):
    """Return a Keras metric-style loss ``CUE_loss(y_true, y_pred)`` bound to input_tensor.

    The loss delegates to cal_LOSS_CUE_CONST_tf; y_true is ignored.
    """
    bound_args = (noise, DUE_thr, I_thr, tx_max, num_sample, log_data_mean, log_data_std, lambda_mat)

    def CUE_loss(y_true, y_pred):
        return cal_LOSS_CUE_CONST_tf(input_tensor, y_pred, *bound_args)

    return CUE_loss




def Init_loss_wrapper(input_tensor):
    """Return a Keras loss/metric closure named ``Init_loss``.

    Unlike the other wrappers in this section, this one uses ``y_true``: it forwards
    the captured model input, the prediction and the target to ``cal_LOSS_init_tf``.
    """
    def Init_loss(y_true, y_pred):
        return cal_LOSS_init_tf(input_tensor, y_pred, y_true)
    return Init_loss




'''

    DNN MODELS

'''


def DNN_basic_module(Input_layer, Num_weights_inner, Num_outputs, Num_layers=3, activation='relu'):
    """Stack fully-connected layers on ``Input_layer`` and return the output tensor.

    The architecture is:
    * Dense(Num_weights_inner) -> Activation(activation);
    * then (Num_layers - 2) repetitions of
      Dense(Num_weights_inner) -> Activation(activation) -> Dropout(0.1);
    * finally a linear Dense(Num_outputs) with no activation.

    Args:
        Input_layer: Keras tensor to build on.
        Num_weights_inner: width of every hidden Dense layer.
        Num_outputs: width of the final linear Dense layer.
        Num_layers: total number of Dense layers, including the output layer.
            Must be at least 2.
        activation: activation applied after each hidden Dense layer. The default
            'relu' matches what this function always produced before; the argument
            is now actually honoured instead of being ignored.

    Returns:
        The output Keras tensor of the final Dense layer.

    Raises:
        ValueError: if ``Num_layers`` is smaller than 2.
    """
    # Validate before creating any layer, so a bad call does not leave orphan layers
    # in the graph. A real exception also survives ``python -O``, unlike ``assert``.
    if Num_layers < 2:
        raise ValueError("Num_layers should be at least 2, got %r" % (Num_layers,))

    hidden = layers.Dense(Num_weights_inner)(Input_layer)
    hidden = layers.Activation(activation)(hidden)

    for _ in range(Num_layers - 2):
        hidden = layers.Dense(Num_weights_inner)(hidden)
        hidden = layers.Activation(activation)(hidden)
        hidden = layers.Dropout(0.1)(hidden)

    return layers.Dense(Num_outputs)(hidden)


def DNN_basic_module_super(Input_layer, Num_weights_inner, Num_outputs, Num_layers=3, activation='relu'):
    """Stack fully-connected layers (no dropout) on ``Input_layer`` and return the output.

    The architecture is:
    * Dense(Num_weights_inner) -> Activation(activation);
    * then (Num_layers - 2) repetitions of Dense(Num_weights_inner) -> Activation(activation);
    * finally a linear Dense(Num_outputs).

    It differs from ``DNN_basic_module`` only in omitting the Dropout layers.

    Args:
        Input_layer: Keras tensor to build on.
        Num_weights_inner: width of every hidden Dense layer.
        Num_outputs: width of the final linear Dense layer.
        Num_layers: total number of Dense layers, including the output layer.
            Must be at least 2.
        activation: activation applied after each hidden Dense layer. The default
            'relu' keeps the previous output; the argument is no longer ignored.

    Returns:
        The output Keras tensor of the final Dense layer.

    Raises:
        ValueError: if ``Num_layers`` is smaller than 2.
    """
    # Check first so an invalid call creates no layers; unlike ``assert``, this
    # check is not stripped under ``python -O``.
    if Num_layers < 2:
        raise ValueError("Num_layers should be at least 2, got %r" % (Num_layers,))

    hidden = layers.Dense(Num_weights_inner)(Input_layer)
    hidden = layers.Activation(activation)(hidden)

    for _ in range(Num_layers - 2):
        hidden = layers.Dense(Num_weights_inner)(hidden)
        hidden = layers.Activation(activation)(hidden)

    return layers.Dense(Num_outputs)(hidden)



"""
    Construct model with full CSI
"""


def DNN_model_full(Num_channel, Num_user, Num_weights, Num_layers=4):
    """Build the unsupervised full-CSI model.

    The input of shape (Num_channel, Num_user + 1, Num_user + 1) is flattened and
    fed into two independent branches, both built with ``DNN_basic_module``:

    * power branch: Num_user outputs, reshaped to (Num_user, 1), then sigmoid;
    * resource-allocation branch: Num_user * Num_channel outputs, reshaped to
      (Num_user, Num_channel), then softmax over the last axis.

    The model output is the element-wise product of the two branches. The
    (Num_user, 1) power tensor is multiplied against every channel column.
    """
    in_shape = (Num_channel, Num_user + 1, Num_user + 1)
    inputs = tf.keras.Input(shape=in_shape)
    flat = layers.Flatten(input_shape=in_shape)(inputs)

    # Power-level branch: one value in (0, 1) per user.
    power = DNN_basic_module(flat, Num_weights, Num_user, Num_layers)
    power = layers.Reshape((Num_user, 1))(power)
    power = layers.Activation('sigmoid')(power)

    # Resource-allocation branch: per-user distribution over channels.
    alloc = DNN_basic_module(flat, Num_weights, Num_user * Num_channel, Num_layers)
    alloc = layers.Reshape((Num_user, Num_channel))(alloc)
    alloc = layers.Activation('softmax')(alloc)

    combined = layers.Multiply()([power, alloc])
    return tf.keras.Model(inputs=inputs, outputs=combined)




def DNN_model_full_super(Num_channel, Num_user, Num_weights, Num_layers=3):
    """Build the supervised full-CSI model.

    The input of shape (Num_channel, Num_user + 1, Num_user + 1) is flattened and
    passed through ``DNN_basic_module_super``, which produces Num_user * Num_channel
    values. These are reshaped to (Num_user, Num_channel) and passed through a ReLU,
    so every output entry is non-negative.
    """
    in_shape = (Num_channel, Num_user + 1, Num_user + 1)
    inputs = tf.keras.Input(shape=in_shape)
    flat = layers.Flatten(input_shape=in_shape)(inputs)

    # Single resource-allocation head.
    head = DNN_basic_module_super(flat, Num_weights, Num_user * Num_channel, Num_layers)
    head = layers.Reshape((Num_user, Num_channel))(head)
    head = layers.Activation('relu')(head)

    return tf.keras.Model(inputs=inputs, outputs=head)






def Print_DNN_OUT(model, log_data, real_data, tx_max, noise, DUE_thr, CUE_thr, P_c):
    """Evaluate ``model`` on ``log_data``, print one metrics line and return (SE, EE, PW).

    Steps:
    1. The network prediction is capped element-wise at 1.0 and scaled by ``tx_max``.
    2. The result is passed, together with the raw channel data ``real_data``, to
       ``cal_rate_NP``.
    3. A line is printed with SE scaled by 1.44, EE scaled by 1.44 * 1000, PW
       unscaled, and the two violation ratios as percentages.

    Note: callers in this file pass ``I_thr`` as ``CUE_thr``.
    """
    # Cap the normalized output at 1.0 before converting it into transmit power.
    tx_power = tx_max * np.minimum(model.predict(log_data), 1.0)
    se, ee, pw, cue_violation, due_violation = cal_rate_NP(
        real_data, tx_power, tx_max, noise, DUE_thr, CUE_thr, P_c)
    report = (se * 1.44, ee * 1.44 * 1000, pw, cue_violation * 100, due_violation * 100)
    print("SE = %0.2f  EE = %0.2f PW = %0.2f, Vio(CUE) = %0.2f  Vio(DUE) = %0.2f" % report)
    return se, ee, pw





def Print_Test_Full(model_SE, model_EE,  model_PW, model_SE_super, model_EE_super, model_PW_super, log_data_test, data_test, Num_power_level,
                    tx_max, noise, DUE_thr, I_thr, P_c, tx_power_set, SE_OPT, EE_OPT, PW_OPT):
    """Evaluate all six models plus a random baseline on the test set and print a report.

    The optimal values ``SE_OPT``/``EE_OPT``/``PW_OPT`` are supplied by the caller.
    Each DNN is evaluated through ``Print_DNN_OUT``, and the random baseline through
    ``random_power``. SE is printed scaled by 1.44 and EE by 1.44 * 1000, matching
    ``Print_DNN_OUT``.

    Returns:
        Three lists ``(SE_MAT, EE_MAT, PW_MAT)``, each ordered as
        [optimal, DNN (SE), DNN (EE), DNN (PW),
         supervised (SE), supervised (EE), supervised (PW), random].
    """
    SE_MAT = [SE_OPT]
    EE_MAT = [EE_OPT]
    PW_MAT = [PW_OPT]

    def _record(se, ee, pw):
        # Append one (SE, EE, PW) triple to the three result lists.
        SE_MAT.append(se)
        EE_MAT.append(ee)
        PW_MAT.append(pw)

    def _evaluate(label, model):
        # Print the label without a newline; Print_DNN_OUT completes the line.
        print(label, end='')
        _record(*Print_DNN_OUT(model, log_data_test, data_test, tx_max, noise, DUE_thr, I_thr, P_c))

    print("Opt: SE = %0.2f, EE = %0.2f, PW = %0.2f" % (SE_OPT * 1.44, EE_OPT * 1.44 * 1000, PW_OPT))
    print("")

    print("")
    print("Proposed Scheme")
    print("")
    for label, model in (("DNN (SE): ", model_SE),
                         ("DNN (EE): ", model_EE),
                         ("DNN (PW): ", model_PW)):
        _evaluate(label, model)

    print("")
    print("SUPERVISER_LEARNING")
    print("")
    for label, model in (("DNN - SUPERVISED LEARNINGN (SE): ", model_SE_super),
                         ("DNN - SUPERVISED LEARNINGN (EE): ", model_EE_super),
                         ("DNN - SUPERVISED LEARNINGN (PW): ", model_PW_super)):
        _evaluate(label, model)

    # Random baseline. random_power runs before its label is printed, as before.
    print("")
    RAN_SE, RAN_EE, RAN_PW, PRO_CUE_vio, PRO_DUE_vio = random_power(data_test, tx_max, Num_power_level, noise, DUE_thr, I_thr, P_c, tx_power_set)
    print("Random case: ", end='')
    print("SE = %0.2f  EE = %0.2f PW = %0.2f, Vio(CUE) = %0.2f  Vio(DUE) = %0.2f"%(RAN_SE*1.44, RAN_EE*1.44*1000, RAN_PW, PRO_CUE_vio*100, PRO_DUE_vio*100))
    _record(RAN_SE, RAN_EE, RAN_PW)

    print("")
    print("")

    return SE_MAT, EE_MAT, PW_MAT




####################################################################################
####################################################################################
####################################################################################
####################################################################################


# Simulation / training configuration used by the experiment loop below.

Num_user = 3            # users per sample; model inputs are (Num_channel, Num_user + 1, Num_user + 1)
Num_channel = 3         # number of channels (HEAD notes it equals the number of CUEs)
Num_power_level = 5     # discrete power levels; all_possible_tx_power receives Num_power_level - 1
Num_layers_full = 4     # Dense layers per DNN branch (passed as Num_layers)
Num_weights_full = 100  # width of every hidden Dense layer


BW = 1e7
# Presumably the -174 dBm/Hz thermal noise density integrated over BW (i.e. noise power
# in mW) -- confirm units against cal_rate_NP / cal_LOSS_*.
noise = BW*10**-17.4
num_samples_init = int(2*1e4)   # channel realisations generated for training (before feasibility filtering)
num_samples_test = int(1*1e4)   # channel realisations generated for testing
Size_area = 30          # NOTE: reassigned at the top of the outer loop, so this value is not used by ch_gen
D2D_dist = 15           # passed to ch_gen as its distance argument (presumably the TX-RX distance limit)
batch_size_set = 50     # fit() batch size; also passed to the loss wrappers as num_sample

# Rates are printed multiplied by 1.44 elsewhere, so this is presumably a 3 bit/s/Hz
# DUE threshold expressed in nats -- confirm.
DUE_thr = 3.0/1.44
tx_max = 10**2.0        # maximum transmit power; the network output (capped at 1) is scaled by this
I_thr = 10**(-50.0/10)  # -50 dB in linear scale, passed to the constraint losses as I_thr
P_c = 10**2.0           # passed as P_c to the EE computations (presumably circuit power)
epoch_num = 2000        # epochs per fit() call



# One entry per outer-loop iteration, each being the lists returned by Print_Test_Full.
SE_MAT_TOT = []
EE_MAT_TOT = []
PW_MAT_TOT = []



# Outer sweep over simulation scenarios. With range(1) only outer_loop == 0 runs; the
# outer_loop terms below (area size, lambda scaling) only matter if the range is widened.
for outer_loop in range(1):
    # Overrides the module-level Size_area (35.0 when outer_loop == 0).
    Size_area = 20.0 + 5 * outer_loop + 15

    # Candidate transmit-power combinations shared by optimal_power and random_power
    # (helper defined elsewhere in this file; its exact layout is not visible here).
    tx_power_set = all_possible_tx_power(Num_channel, Num_user, Num_power_level - 1)
    print(tx_power_set.shape)

    # Generate training channels, then compute the optimum benchmarks with optimal_power
    # (defined elsewhere). Judging from the names and their use below, it returns:
    #   - the average optimal SE/EE/PW;
    #   - per-sample optimal power matrices (SE_PW_MAT etc., the supervised labels);
    #   - the feasible subset of channels (data_train).
    # Confirm against its definition.
    data_train_full = np.array(ch_gen(Size_area, D2D_dist, Num_user, Num_channel, num_samples_init))
    SE_OPT, EE_OPT, PW_OPT, SE_PW_MAT, EE_PW_MAT, PW_PW_MAT, data_train = optimal_power(data_train_full, tx_max, Num_power_level, noise, DUE_thr, I_thr, P_c, tx_power_set)


    print("Number of feasible sets:  ", data_train.shape)
    print("Opt: SE = %0.2f, EE = %0.2f, PW = %0.2f" % (SE_OPT * 1.44, EE_OPT * 1.44 * 1000, PW_OPT))
    print("")

    # Trim every array to a whole number of batches of batch_size_set. The loss wrappers
    # below are built with num_sample=batch_size_set, presumably assuming full batches.
    # data_train is trimmed first, but its new length is already a multiple of
    # batch_size_set, so the three label arrays are trimmed to that same length.
    data_train = data_train[:batch_size_set * (data_train.shape[0] // batch_size_set)]
    SE_PW_MAT = SE_PW_MAT[:batch_size_set * (data_train.shape[0] // batch_size_set)]
    EE_PW_MAT = EE_PW_MAT[:batch_size_set * (data_train.shape[0] // batch_size_set)]
    PW_PW_MAT = PW_PW_MAT[:batch_size_set * (data_train.shape[0] // batch_size_set)]


    ## Recalculate the number of feasible solutions
    num_samples = data_train.shape[0]
    # Dummy targets for the unsupervised fit() calls. The Rate/EE/PW/CUE/DUE closures in
    # this file ignore y_true; presumably Total_*_loss_wrapper does too -- confirm.
    labels = np.zeros(num_samples, )
    # Log-transform and standardise with a single scalar mean/std over the whole array.
    # The same statistics are reused for the test set and passed to the loss wrappers.
    log_data = np.log(data_train)
    log_data_mean = np.mean(log_data)
    log_data_std = np.std(log_data)
    log_data = (log_data - log_data_mean) / log_data_std


    # print(log_data.shape)
    # aa = np.random.randn(10000, 2, 3, 3)
    # print(aa)
    # print("")




    print("")
    print("Outer_loop: %d "%outer_loop)

    learning_rate_cur = 1e-4
    # (3, 1) weight vectors passed to the loss wrappers as lambda_mat: one for the SE
    # model, one for EE and one for PW. Entry 0 is set apart from entries 1 and 2. What
    # each entry weighs is decided by the cal_LOSS_* functions (not visible here);
    # presumably [objective, constraint penalty, constraint penalty] -- confirm.
    lambda_mat = np.ones((3, 1)) * 40 * (outer_loop+2)
    lambda_mat[0] = 1
    lambda_mat[2] = lambda_mat[2] * 0.2


    lambda_ee_mat = np.ones((3, 1)) * 2000 * (outer_loop+2)
    lambda_ee_mat[0] = 0.25
    lambda_ee_mat[2] = lambda_ee_mat[2] * 1.2

    lambda_pw_mat = np.ones((3, 1)) * 1000.0 * (outer_loop+2)
    lambda_pw_mat[0] = 0.1
    lambda_pw_mat[1] = 2000.0



    # Not referenced anywhere below in this script.
    cap_full = []
    cap_dist = []
    prob_full = []
    prob_dist = []

    ######################################################################################################
    # Unsupervised SE model. The loss is Total_SE_loss_wrapper (defined elsewhere); the
    # metrics report the rate, EE, power and CUE/DUE constraint terms.
    model_SE = DNN_model_full(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_SE.compile(optimizer=tf.train.AdamOptimizer(learning_rate_cur),
                     loss=Total_SE_loss_wrapper(model_SE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_mat),
                     metrics=[
                         Rate_loss_wrapper(model_SE.input, noise, tx_max, batch_size_set,
                                           log_data_mean, log_data_std),
                         EE_loss_wrapper(model_SE.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std, P_c),
                         PW_loss_wrapper(model_SE.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std),
                         CUE_CONST_loss_wrapper(model_SE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_mat),
                         DUE_CONST_loss_wrapper(model_SE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_mat)
                     ])

    ######################################################################################################
    ######################################################################################################

    # Unsupervised EE model: same architecture and metrics, with the EE loss and lambda_ee_mat.
    model_EE = DNN_model_full(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_EE.compile(optimizer=tf.train.AdamOptimizer(learning_rate_cur),
                     loss=Total_EE_loss_wrapper(model_EE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_ee_mat, P_c),
                     metrics=[
                         Rate_loss_wrapper(model_EE.input, noise, tx_max, batch_size_set,
                                           log_data_mean, log_data_std),
                         EE_loss_wrapper(model_EE.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std, P_c),
                         PW_loss_wrapper(model_EE.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std),
                         CUE_CONST_loss_wrapper(model_EE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_ee_mat),
                         DUE_CONST_loss_wrapper(model_EE.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_ee_mat)
                     ])

    ######################################################################################################
    ######################################################################################################

    # Unsupervised PW model: same architecture and metrics, with the PW loss and lambda_pw_mat.
    model_PW = DNN_model_full(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_PW.compile(optimizer=tf.train.AdamOptimizer(learning_rate_cur),
                     loss=Total_PW_loss_wrapper(model_PW.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_pw_mat),
                     metrics=[
                         Rate_loss_wrapper(model_PW.input, noise, tx_max, batch_size_set,
                                           log_data_mean, log_data_std),
                         EE_loss_wrapper(model_PW.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std, P_c),
                         PW_loss_wrapper(model_PW.input, noise, tx_max, batch_size_set,
                                         log_data_mean, log_data_std),
                         CUE_CONST_loss_wrapper(model_PW.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_pw_mat),
                         DUE_CONST_loss_wrapper(model_PW.input, noise, DUE_thr, I_thr, tx_max,
                                                batch_size_set,
                                                log_data_mean, log_data_std, lambda_pw_mat)
                     ])

    ######################################################################################################


    ######################################################################################################
    # Supervised counterparts. Each is trained with Init_loss_wrapper against the optimal
    # power matrices divided by tx_max (see the commented-out fit calls below).
    # NOTE(review): the CUE/DUE metrics of all three supervised models use lambda_pw_mat,
    # even for the SE and EE models. This looks copy-pasted; it only changes the reported
    # metric values, not the training loss.
    model_SE_super = DNN_model_full_super(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_SE_super.compile(optimizer=tf.train.RMSPropOptimizer(1e-4),
                     loss=Init_loss_wrapper(model_SE_super.input),
                     metrics=[Init_loss_wrapper(model_SE_super.input),
                              Rate_loss_wrapper(model_SE_super.input, noise, tx_max, batch_size_set,
                                                log_data_mean, log_data_std),
                              EE_loss_wrapper(model_SE_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std, P_c),
                              PW_loss_wrapper(model_SE_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std),
                              CUE_CONST_loss_wrapper(model_SE_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat),
                              DUE_CONST_loss_wrapper(model_SE_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat)
                     ])

    ######################################################################################################
    model_EE_super = DNN_model_full_super(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_EE_super.compile(optimizer=tf.train.RMSPropOptimizer(1e-4),
                     loss=Init_loss_wrapper(model_EE_super.input),
                     metrics=[Init_loss_wrapper(model_EE_super.input),
                              Rate_loss_wrapper(model_EE_super.input, noise, tx_max, batch_size_set,
                                                log_data_mean, log_data_std),
                              EE_loss_wrapper(model_EE_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std, P_c),
                              PW_loss_wrapper(model_EE_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std),
                              CUE_CONST_loss_wrapper(model_EE_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat),
                              DUE_CONST_loss_wrapper(model_EE_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat)
                     ])


    ######################################################################################################
    model_PW_super = DNN_model_full_super(Num_channel, Num_user, Num_weights_full, Num_layers_full)
    model_PW_super.compile(optimizer=tf.train.RMSPropOptimizer(1e-4),
                     loss=Init_loss_wrapper(model_PW_super.input),
                     metrics=[Init_loss_wrapper(model_PW_super.input),
                              Rate_loss_wrapper(model_PW_super.input, noise, tx_max, batch_size_set,
                                                log_data_mean, log_data_std),
                              EE_loss_wrapper(model_PW_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std, P_c),
                              PW_loss_wrapper(model_PW_super.input, noise, tx_max, batch_size_set,
                                              log_data_mean, log_data_std),
                              CUE_CONST_loss_wrapper(model_PW_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat),
                              DUE_CONST_loss_wrapper(model_PW_super.input, noise, DUE_thr, I_thr, tx_max,
                                                     batch_size_set,
                                                     log_data_mean, log_data_std, lambda_pw_mat)
                     ])




    # Single training pass (range(1)); each active fit() runs epoch_num epochs silently.
    for i in range(1):


        # NOTE(review): this fit is commented out, so model_SE stays at its random initial
        # weights. The "finished" message below still prints, and the untrained model is
        # still evaluated on the test set.
        #model_SE.fit(log_data, labels, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of SE is finished")
        print()
        model_EE.fit(log_data, labels, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of EE is finished")
        print()
        model_PW.fit(log_data, labels, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of PW is finished")
        print()
        # NOTE(review): the three supervised fits below are also commented out, so the
        # *_super models are evaluated untrained despite the "finished" messages.
        #model_SE_super.fit(log_data, SE_PW_MAT / tx_max, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of SE (SUPERVISED) is finished")
        print()
        #model_EE_super.fit(log_data, EE_PW_MAT / tx_max, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of EE (SUPERVISED) is finished")
        print()
        #model_PW_super.fit(log_data, PW_PW_MAT / tx_max, batch_size=batch_size_set, epochs=epoch_num, verbose=0)
        print()
        print("Training of PW (SUPERVISED) is finished")
        print()




        print("%d-th iteration is finished  " % i)






    ####################################################################
    ## Generate the test channel set
    ## Only the valid channel is used for test
    data_test_full = np.array(ch_gen(Size_area, D2D_dist, Num_user, Num_channel, num_samples_test))

    # Same optimal_power call as for training; only the averages and the feasible
    # subset (data_test) are kept.
    SE_OPT_test, EE_OPT_test, PW_OPT_test, _, _, _, data_test = optimal_power(data_test_full, tx_max, Num_power_level, noise, DUE_thr, I_thr, P_c, tx_power_set)


    # Normalise with the TRAINING mean/std so test inputs match what the models were built with.
    log_data_test = np.log(data_test)
    log_data_test = (log_data_test - log_data_mean) / log_data_std

    ####################################################################

    print("test_set result")
    SE_val, EE_val, PW_val = Print_Test_Full(model_SE, model_EE, model_PW, model_SE_super, model_EE_super, model_PW_super, log_data_test, data_test, Num_power_level, tx_max, noise, DUE_thr, I_thr, P_c, tx_power_set, SE_OPT_test, EE_OPT_test, PW_OPT_test)






    # Each row is ordered [optimal, DNN SE/EE/PW, supervised SE/EE/PW, random].
    SE_MAT_TOT.append(SE_val)
    EE_MAT_TOT.append(EE_val)
    PW_MAT_TOT.append(PW_val)




# Final report: one row per outer-loop iteration. SE is scaled by 1.44 and EE by
# 1.44 * 1000, matching the per-run printouts; PW is printed unscaled.
for _ in range(3):
    print("")
for title, rows, factors in (("SE: ", SE_MAT_TOT, (1.44,)),
                             ("EE: ", EE_MAT_TOT, (1.44, 1000)),
                             ("PW: ", PW_MAT_TOT, ())):
    table = np.array(rows)
    # Apply the factors one by one to keep the original multiplication order.
    for factor in factors:
        table = table * factor
    print(title)
    print(table)
    print("")
